// LLVM implements part of the ABI lowering internally, but sometimes requires
// the frontend to pack and unpack values itself. Doing so can help LLVM by
// exposing those operations to the optimizer, but it requires the frontend to
// know exactly what IR to produce in order to match the target ABI.
//
// Currently, only the System V AMD64 ABI is implemented here.

#![deny(
    dead_code,
    missing_docs,
)]

use crate::translator::intrinsics::{type_to_llvm, Intrinsics};
use inkwell::{
    attributes::{Attribute, AttributeLoc},
    builder::Builder,
    context::Context,
    types::{BasicType, FunctionType, StructType},
    values::{
        BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue,
        PointerValue, VectorValue,
    },
    AddressSpace,
};
use wasmer_compiler::CompileError;
use wasmer_types::{FunctionType as FuncSig, Type};

/// Given a function definition, returns the parameter that holds the vmctx pointer.
///
/// The vmctx pointer is normally parameter 0. When parameter 0 carries the
/// `sret` attribute (a hidden struct-return pointer), vmctx is shifted to
/// parameter 1.
///
/// # Panics
///
/// Panics if the selected parameter does not exist.
pub fn get_vmctx_ptr_param<'ctx>(func_value: &FunctionValue<'ctx>) -> PointerValue<'ctx> {
    let sret_kind = Attribute::get_named_enum_kind_id("sret");
    let has_sret_param = func_value
        .get_enum_attribute(AttributeLoc::Param(0), sret_kind)
        .is_some();
    let vmctx_index = if has_sret_param { 1 } else { 0 };

    func_value
        .get_nth_param(vmctx_index)
        .unwrap()
        .into_pointer_value()
}

// Given a wasm function type, produce an llvm function declaration.
pub fn func_type_to_llvm<'ctx>(
    context: &'ctx Context,
    intrinsics: &Intrinsics<'ctx>,
    sig: &FuncSig,
) -> Result<(FunctionType<'ctx>, Vec<(Attribute, AttributeLoc)>), CompileError> {
    let user_param_types = sig.params().iter().map(|&ty| type_to_llvm(intrinsics, ty));

    let param_types =
        std::iter::once(Ok(intrinsics.ctx_ptr_ty.as_basic_type_enum())).chain(user_param_types);

    let sig_returns_bitwidths = sig
        .results()
        .iter()
        .map(|ty| match ty {
            Type::I32 | Type::F32 => 32,
            Type::I64 | Type::F64 => 64,
            Type::V128 => 128,
            Type::ExternRef => unimplemented!("externref in the llvm backend"),
            Type::FuncRef => unimplemented!("funcref in the llvm backend"),
        })
        .collect::<Vec<i32>>();

    Ok(match sig_returns_bitwidths.as_slice() {
        [] => (
            intrinsics
                .void_ty
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [_] => {
            let single_value = sig.results()[0];
            (
                type_to_llvm(intrinsics, single_value)?
                    .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
                vec![],
            )
        }
        [32, 64] | [64, 32] | [64, 64] => {
            let basic_types: Vec<_> = sig
                .results()
                .iter()
                .map(|&ty| type_to_llvm(intrinsics, ty))
                .collect::<Result<_, _>>()?;

            (
                context
                    .struct_type(&basic_types, false)
                    .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
                vec![],
            )
        }
        [32, 32] if sig.results()[0] == Type::F32 && sig.results()[1] == Type::F32 => (
            intrinsics
                .f32_ty
                .vec_type(2)
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [32, 32] => (
            intrinsics
                .i64_ty
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [32, 32, _] if sig.results()[0] == Type::F32 && sig.results()[1] == Type::F32 => (
            context
                .struct_type(
                    &[
                        intrinsics.f32_ty.vec_type(2).as_basic_type_enum(),
                        type_to_llvm(intrinsics, sig.results()[2])?,
                    ],
                    false,
                )
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [32, 32, _] => (
            context
                .struct_type(
                    &[
                        intrinsics.i64_ty.as_basic_type_enum(),
                        type_to_llvm(intrinsics, sig.results()[2])?,
                    ],
                    false,
                )
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [64, 32, 32] if sig.results()[1] == Type::F32 && sig.results()[2] == Type::F32 => (
            context
                .struct_type(
                    &[
                        type_to_llvm(intrinsics, sig.results()[0])?,
                        intrinsics.f32_ty.vec_type(2).as_basic_type_enum(),
                    ],
                    false,
                )
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [64, 32, 32] => (
            context
                .struct_type(
                    &[
                        type_to_llvm(intrinsics, sig.results()[0])?,
                        intrinsics.i64_ty.as_basic_type_enum(),
                    ],
                    false,
                )
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        [32, 32, 32, 32] => (
            context
                .struct_type(
                    &[
                        if sig.results()[0] == Type::F32 && sig.results()[1] == Type::F32 {
                            intrinsics.f32_ty.vec_type(2).as_basic_type_enum()
                        } else {
                            intrinsics.i64_ty.as_basic_type_enum()
                        },
                        if sig.results()[2] == Type::F32 && sig.results()[3] == Type::F32 {
                            intrinsics.f32_ty.vec_type(2).as_basic_type_enum()
                        } else {
                            intrinsics.i64_ty.as_basic_type_enum()
                        },
                    ],
                    false,
                )
                .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
            vec![],
        ),
        _ => {
            let basic_types: Vec<_> = sig
                .results()
                .iter()
                .map(|&ty| type_to_llvm(intrinsics, ty))
                .collect::<Result<_, _>>()?;

            let sret = context
                .struct_type(&basic_types, false)
                .ptr_type(AddressSpace::Generic);

            let param_types = std::iter::once(Ok(sret.as_basic_type_enum())).chain(param_types);

            (
                intrinsics
                    .void_ty
                    .fn_type(&param_types.collect::<Result<Vec<_>, _>>()?, false),
                vec![(
                    context.create_enum_attribute(Attribute::get_named_enum_kind_id("sret"), 0),
                    AttributeLoc::Param(0),
                )],
            )
        }
    })
}

/// Marshal wasm stack values into the argument list for a call to a function
/// of type `llvm_fn_ty`.
///
/// The arguments are laid out as `[sret?, ctx_ptr, values...]`. The leading
/// `sret` slot is present only when the callee returns `void` while
/// `func_sig` has more than one result. In that case a stack slot for the
/// result struct is allocated with `alloca_builder` and passed as that
/// argument.
pub fn args_to_call<'ctx>(
    alloca_builder: &Builder<'ctx>,
    func_sig: &FuncSig,
    ctx_ptr: PointerValue<'ctx>,
    llvm_fn_ty: &FunctionType<'ctx>,
    values: &[BasicValueEnum<'ctx>],
) -> Vec<BasicValueEnum<'ctx>> {
    let mut args = Vec::with_capacity(values.len() + 2);

    let uses_sret = llvm_fn_ty.get_return_type().is_none() && func_sig.results().len() > 1;
    if uses_sret {
        // Parameter 0 is a pointer to the result struct; allocate one to pass.
        let result_struct_ty = llvm_fn_ty.get_param_types()[0]
            .into_pointer_type()
            .get_element_type()
            .into_struct_type();
        let sret_slot = alloca_builder.build_alloca(result_struct_ty, "sret");
        args.push(sret_slot.as_basic_value_enum());
    }

    args.push(ctx_ptr.as_basic_value_enum());
    args.extend_from_slice(values);
    args
}

/// Given a CallSite, extract the returned values and return them in a Vec,
/// one value per result in `func_sig`.
///
/// The callee's return is handled in one of these shapes (NOTE(review): these
/// mirror the lowering in `func_type_to_llvm`; assumed to be the callee's
/// origin, confirm against callers):
///
/// * a single value returned directly;
/// * several results returned in registers: an `i64` holding two 32-bit
///   results (low half first), a `<2 x float>` holding two `f32`s, or a
///   struct whose fields are such packed pairs and/or whole values. These are
///   unpacked back into individual values;
/// * no direct return value and an `sret` attribute on argument 0: the
///   result struct is loaded through that pointer and split into its fields;
/// * no direct return value and no `sret`: the function has no results.
///
/// # Panics
///
/// Panics (via `assert!`/`unreachable!`) if the call's return shape does not
/// match `func_sig`. It also panics on some unpacking paths when a result is
/// `externref`/`funcref` (the `unimplemented!` calls below).
pub fn rets_from_call<'ctx>(
    builder: &Builder<'ctx>,
    intrinsics: &Intrinsics<'ctx>,
    call_site: CallSiteValue<'ctx>,
    func_sig: &FuncSig,
) -> Vec<BasicValueEnum<'ctx>> {
    // Split an i64 into its (low, high) 32-bit halves: truncate for the low
    // half, logical shift right by 32 then truncate for the high half.
    let split_i64 = |value: IntValue<'ctx>| -> (IntValue<'ctx>, IntValue<'ctx>) {
        assert!(value.get_type() == intrinsics.i64_ty);
        let low = builder.build_int_truncate(value, intrinsics.i32_ty, "");
        let lshr =
            builder.build_right_shift(value, intrinsics.i64_ty.const_int(32, false), false, "");
        let high = builder.build_int_truncate(lshr, intrinsics.i32_ty, "");
        (low, high)
    };

    // Extract lanes 0 and 1 of a `<2 x float>`.
    let f32x2_ty = intrinsics.f32_ty.vec_type(2).as_basic_type_enum();
    let extract_f32x2 = |value: VectorValue<'ctx>| -> (FloatValue<'ctx>, FloatValue<'ctx>) {
        assert!(value.get_type() == f32x2_ty.into_vector_type());
        let ret0 = builder
            .build_extract_element(value, intrinsics.i32_ty.const_int(0, false), "")
            .into_float_value();
        let ret1 = builder
            .build_extract_element(value, intrinsics.i32_ty.const_int(1, false), "")
            .into_float_value();
        (ret0, ret1)
    };

    // Bitcast a scalar of matching width to the LLVM type of wasm type `ty`.
    // Halves split out of a packed i64 are i32, so f32 results need this to
    // get back their float type. V128 values are expected as i128 and are
    // returned unchanged.
    let casted = |value: BasicValueEnum<'ctx>, ty: Type| -> BasicValueEnum<'ctx> {
        match ty {
            Type::I32 => {
                assert!(
                    value.get_type() == intrinsics.i32_ty.as_basic_type_enum()
                        || value.get_type() == intrinsics.f32_ty.as_basic_type_enum()
                );
                builder.build_bitcast(value, intrinsics.i32_ty, "")
            }
            Type::F32 => {
                assert!(
                    value.get_type() == intrinsics.i32_ty.as_basic_type_enum()
                        || value.get_type() == intrinsics.f32_ty.as_basic_type_enum()
                );
                builder.build_bitcast(value, intrinsics.f32_ty, "")
            }
            Type::I64 => {
                assert!(
                    value.get_type() == intrinsics.i64_ty.as_basic_type_enum()
                        || value.get_type() == intrinsics.f64_ty.as_basic_type_enum()
                );
                builder.build_bitcast(value, intrinsics.i64_ty, "")
            }
            Type::F64 => {
                assert!(
                    value.get_type() == intrinsics.i64_ty.as_basic_type_enum()
                        || value.get_type() == intrinsics.f64_ty.as_basic_type_enum()
                );
                builder.build_bitcast(value, intrinsics.f64_ty, "")
            }
            Type::V128 => {
                assert!(value.get_type() == intrinsics.i128_ty.as_basic_type_enum());
                value
            }
            Type::ExternRef => unimplemented!("externref in the llvm backend"),
            Type::FuncRef => unimplemented!("funcref in the llvm backend"),
        }
    };

    // The call produced a value directly (no sret).
    if let Some(basic_value) = call_site.try_as_basic_value().left() {
        if func_sig.results().len() > 1 {
            // Two 32-bit results packed into a single i64.
            if basic_value.get_type() == intrinsics.i64_ty.as_basic_type_enum() {
                assert!(func_sig.results().len() == 2);
                let value = basic_value.into_int_value();
                let (low, high) = split_i64(value);
                let low = casted(low.into(), func_sig.results()[0]);
                let high = casted(high.into(), func_sig.results()[1]);
                return vec![low, high];
            }
            // Two f32 results packed into a <2 x float>.
            if basic_value.get_type() == f32x2_ty {
                assert!(func_sig.results().len() == 2);
                let (ret0, ret1) = extract_f32x2(basic_value.into_vector_value());
                return vec![ret0.into(), ret1.into()];
            }
            // Otherwise the return is a struct. Extract its fields, then use
            // the results' bit widths to decide which fields are packed pairs
            // that still need to be unpacked.
            let struct_value = basic_value.into_struct_value();
            let rets = (0..struct_value.get_type().count_fields())
                .map(|i| builder.build_extract_value(struct_value, i, "").unwrap())
                .collect::<Vec<_>>();
            let func_sig_returns_bitwidths = func_sig
                .results()
                .iter()
                .map(|ty| match ty {
                    Type::I32 | Type::F32 => 32,
                    Type::I64 | Type::F64 => 64,
                    Type::V128 => 128,
                    Type::ExternRef => unimplemented!("externref in the llvm backend"),
                    Type::FuncRef => unimplemented!("funcref in the llvm backend"),
                })
                .collect::<Vec<i32>>();

            match func_sig_returns_bitwidths.as_slice() {
                // Each field is already one whole result.
                [32, 64] | [64, 32] | [64, 64] => {
                    assert!(func_sig.results().len() == 2);
                    vec![rets[0], rets[1]]
                }
                // Field 0 packs results 0 and 1; field 1 is result 2.
                [32, 32, _]
                    if rets[0].get_type() == intrinsics.f32_ty.vec_type(2).as_basic_type_enum() =>
                {
                    assert!(func_sig.results().len() == 3);
                    let (rets0, rets1) = extract_f32x2(rets[0].into_vector_value());
                    vec![rets0.into(), rets1.into(), rets[1]]
                }
                [32, 32, _] => {
                    assert!(func_sig.results().len() == 3);
                    let (low, high) = split_i64(rets[0].into_int_value());
                    let low = casted(low.into(), func_sig.results()[0]);
                    let high = casted(high.into(), func_sig.results()[1]);
                    vec![low, high, rets[1]]
                }
                // Field 0 is result 0; field 1 packs results 1 and 2.
                [64, 32, 32]
                    if rets[1].get_type() == intrinsics.f32_ty.vec_type(2).as_basic_type_enum() =>
                {
                    assert!(func_sig.results().len() == 3);
                    let (rets1, rets2) = extract_f32x2(rets[1].into_vector_value());
                    vec![rets[0], rets1.into(), rets2.into()]
                }
                [64, 32, 32] => {
                    assert!(func_sig.results().len() == 3);
                    let (rets1, rets2) = split_i64(rets[1].into_int_value());
                    let rets1 = casted(rets1.into(), func_sig.results()[1]);
                    let rets2 = casted(rets2.into(), func_sig.results()[2]);
                    vec![rets[0], rets1, rets2]
                }
                // Two fields, each packing a pair of 32-bit results, either
                // as <2 x float> or as i64.
                [32, 32, 32, 32] => {
                    assert!(func_sig.results().len() == 4);
                    let (low0, high0) = if rets[0].get_type()
                        == intrinsics.f32_ty.vec_type(2).as_basic_type_enum()
                    {
                        let (x, y) = extract_f32x2(rets[0].into_vector_value());
                        (x.into(), y.into())
                    } else {
                        let (x, y) = split_i64(rets[0].into_int_value());
                        (x.into(), y.into())
                    };
                    let (low1, high1) = if rets[1].get_type()
                        == intrinsics.f32_ty.vec_type(2).as_basic_type_enum()
                    {
                        let (x, y) = extract_f32x2(rets[1].into_vector_value());
                        (x.into(), y.into())
                    } else {
                        let (x, y) = split_i64(rets[1].into_int_value());
                        (x.into(), y.into())
                    };
                    let low0 = casted(low0, func_sig.results()[0]);
                    let high0 = casted(high0, func_sig.results()[1]);
                    let low1 = casted(low1, func_sig.results()[2]);
                    let high1 = casted(high1, func_sig.results()[3]);
                    vec![low0, high0, low1, high1]
                }
                _ => unreachable!("expected an sret for this type"),
            }
        } else {
            assert!(func_sig.results().len() == 1);
            vec![basic_value]
        }
    } else {
        // No direct return value: either the results come back through an
        // sret pointer, or the function has no results at all.
        assert!(call_site.count_arguments() > 0); // Either sret or vmctx.
        if call_site
            .get_enum_attribute(
                AttributeLoc::Param(0),
                Attribute::get_named_enum_kind_id("sret"),
            )
            .is_some()
        {
            // Operand 0 of the call instruction is the sret pointer argument.
            let sret = call_site
                .try_as_basic_value()
                .right()
                .unwrap()
                .get_operand(0)
                .unwrap()
                .left()
                .unwrap()
                .into_pointer_value();
            // Load the whole result struct, then split it into one value per field.
            let struct_value = builder.build_load(sret, "").into_struct_value();
            let mut rets: Vec<_> = Vec::new();
            for i in 0..struct_value.get_type().count_fields() {
                let value = builder.build_extract_value(struct_value, i, "").unwrap();
                rets.push(value);
            }
            assert!(func_sig.results().len() == rets.len());
            rets
        } else {
            assert!(func_sig.results().is_empty());
            vec![]
        }
    }
}

/// Returns `true` when a function with signature `func_sig` returns its
/// results through a hidden `sret` pointer parameter instead of registers.
///
/// This must classify signatures exactly like the return lowering in
/// `func_type_to_llvm`. That function keeps every `[32, 32, _]` result shape
/// in registers, including one whose third result is a `v128`.
///
/// # Errors
///
/// Returns `CompileError::Codegen` if a result type has no bit-width
/// classification here (e.g. `externref` / `funcref`).
pub fn is_sret(func_sig: &FuncSig) -> Result<bool, CompileError> {
    let func_sig_returns_bitwidths = func_sig
        .results()
        .iter()
        .map(|ty| match ty {
            Type::I32 | Type::F32 => Ok(32),
            Type::I64 | Type::F64 => Ok(64),
            Type::V128 => Ok(128),
            ty => Err(CompileError::Codegen(format!(
                "is_sret: unimplemented wasmer_types type {:?}",
                ty
            ))),
        })
        .collect::<Result<Vec<i32>, _>>()?;

    Ok(match func_sig_returns_bitwidths.as_slice() {
        []
        | [_]
        | [32, 64]
        | [64, 32]
        | [64, 64]
        | [32, 32]
        // A packed 32-bit pair followed by ANY third result (32, 64 or 128
        // bits) is returned as a two-field struct in registers. Listing only
        // 32/64 here made `[32, 32, 128]` claim an sret parameter that the
        // declared function type does not have.
        | [32, 32, _]
        | [64, 32, 32]
        | [32, 32, 32, 32] => false,
        _ => true,
    })
}

pub fn pack_values_for_register_return<'ctx>(
    intrinsics: &Intrinsics<'ctx>,
    builder: &Builder<'ctx>,
    values: &[BasicValueEnum<'ctx>],
    func_type: &FunctionType<'ctx>,
) -> Result<BasicValueEnum<'ctx>, CompileError> {
    let is_32 = |value: BasicValueEnum| {
        (value.is_int_value() && value.into_int_value().get_type() == intrinsics.i32_ty)
            || (value.is_float_value() && value.into_float_value().get_type() == intrinsics.f32_ty)
    };
    let is_64 = |value: BasicValueEnum| {
        (value.is_int_value() && value.into_int_value().get_type() == intrinsics.i64_ty)
            || (value.is_float_value() && value.into_float_value().get_type() == intrinsics.f64_ty)
    };
    let is_f32 = |value: BasicValueEnum| {
        value.is_float_value() && value.into_float_value().get_type() == intrinsics.f32_ty
    };

    let pack_i32s = |low: BasicValueEnum<'ctx>, high: BasicValueEnum<'ctx>| {
        assert!(low.get_type() == intrinsics.i32_ty.as_basic_type_enum());
        assert!(high.get_type() == intrinsics.i32_ty.as_basic_type_enum());
        let (low, high) = (low.into_int_value(), high.into_int_value());
        let low = builder.build_int_z_extend(low, intrinsics.i64_ty, "");
        let high = builder.build_int_z_extend(high, intrinsics.i64_ty, "");
        let high = builder.build_left_shift(high, intrinsics.i64_ty.const_int(32, false), "");
        builder.build_or(low, high, "").as_basic_value_enum()
    };

    let pack_f32s = |first: BasicValueEnum<'ctx>,
                     second: BasicValueEnum<'ctx>|
     -> BasicValueEnum<'ctx> {
        assert!(first.get_type() == intrinsics.f32_ty.as_basic_type_enum());
        assert!(second.get_type() == intrinsics.f32_ty.as_basic_type_enum());
        let (first, second) = (first.into_float_value(), second.into_float_value());
        let vec_ty = intrinsics.f32_ty.vec_type(2);
        let vec = builder.build_insert_element(vec_ty.get_undef(), first, intrinsics.i32_zero, "");
        builder
            .build_insert_element(vec, second, intrinsics.i32_ty.const_int(1, false), "")
            .as_basic_value_enum()
    };

    let build_struct = |ty: StructType<'ctx>, values: &[BasicValueEnum<'ctx>]| {
        let mut struct_value = ty.get_undef();
        for (i, v) in values.iter().enumerate() {
            struct_value = builder
                .build_insert_value(struct_value, *v, i as u32, "")
                .unwrap()
                .into_struct_value();
        }
        struct_value.as_basic_value_enum()
    };

    Ok(match *values {
        [one_value] => one_value,
        [v1, v2] if is_f32(v1) && is_f32(v2) => pack_f32s(v1, v2),
        [v1, v2] if is_32(v1) && is_32(v2) => {
            let v1 = builder.build_bitcast(v1, intrinsics.i32_ty, "");
            let v2 = builder.build_bitcast(v2, intrinsics.i32_ty, "");
            pack_i32s(v1, v2)
        }
        [v1, v2] => {
            assert!(!(is_32(v1) && is_32(v2)));
            build_struct(
                func_type.get_return_type().unwrap().into_struct_type(),
                &[v1, v2],
            )
        }
        [v1, v2, v3] if is_f32(v1) && is_f32(v2) => build_struct(
            func_type.get_return_type().unwrap().into_struct_type(),
            &[pack_f32s(v1, v2), v3],
        ),
        [v1, v2, v3] if is_32(v1) && is_32(v2) => {
            let v1 = builder.build_bitcast(v1, intrinsics.i32_ty, "");
            let v2 = builder.build_bitcast(v2, intrinsics.i32_ty, "");
            build_struct(
                func_type.get_return_type().unwrap().into_struct_type(),
                &[pack_i32s(v1, v2), v3],
            )
        }
        [v1, v2, v3] if is_64(v1) && is_f32(v2) && is_f32(v3) => build_struct(
            func_type.get_return_type().unwrap().into_struct_type(),
            &[v1, pack_f32s(v2, v3)],
        ),
        [v1, v2, v3] if is_64(v1) && is_32(v2) && is_32(v3) => {
            let v2 = builder.build_bitcast(v2, intrinsics.i32_ty, "");
            let v3 = builder.build_bitcast(v3, intrinsics.i32_ty, "");
            build_struct(
                func_type.get_return_type().unwrap().into_struct_type(),
                &[v1, pack_i32s(v2, v3)],
            )
        }
        [v1, v2, v3, v4] if is_32(v1) && is_32(v2) && is_32(v3) && is_32(v4) => {
            let v1v2_pack = if is_f32(v1) && is_f32(v2) {
                pack_f32s(v1, v2)
            } else {
                let v1 = builder.build_bitcast(v1, intrinsics.i32_ty, "");
                let v2 = builder.build_bitcast(v2, intrinsics.i32_ty, "");
                pack_i32s(v1, v2)
            };
            let v3v4_pack = if is_f32(v3) && is_f32(v4) {
                pack_f32s(v3, v4)
            } else {
                let v3 = builder.build_bitcast(v3, intrinsics.i32_ty, "");
                let v4 = builder.build_bitcast(v4, intrinsics.i32_ty, "");
                pack_i32s(v3, v4)
            };
            build_struct(
                func_type.get_return_type().unwrap().into_struct_type(),
                &[v1v2_pack, v3v4_pack],
            )
        }
        _ => unreachable!("called to perform register return on struct return or void function"),
    })
}
